#
# This source file is part of the EdgeDB open source project.
#
# Copyright 2020-present MagicStack Inc. and the EdgeDB authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import annotations

import asyncio
import json
import textwrap
from typing import Mapping

from edb.tools.edb import edbcommands


async def run():
    """Introspect local PostgreSQL servers and regenerate the declarations.

    Connects to one PostgreSQL instance per hard-coded local port, collects
    the column layout of ``information_schema`` and ``pg_catalog`` from each
    of them via ``query_instance``, and then overwrites
    ``./edb/pgsql/resolver/sql_introspection.py`` (relative to the current
    working directory) with the generated Python source.
    """
    import asyncpg

    # Filled in by query_instance(): major version -> introspected schemas.
    schemas_by_version = {}

    # docker run -p 5413:5432 --rm -e POSTGRES_PASSWORD=pass postgres:13
    # docker run -p 5414:5432 --rm -e POSTGRES_PASSWORD=pass postgres:14
    # docker run -p 5415:5432 --rm -e POSTGRES_PASSWORD=pass postgres:15
    # docker run -p 5416:5432 --rm -e POSTGRES_PASSWORD=pass postgres:16
    # docker run -p 5417:5432 --rm -e POSTGRES_PASSWORD=pass postgres:17
    for port in [5413, 5414, 5415, 5416, 5417]:
        connection = await asyncpg.connect(
            host='localhost',
            database='postgres',
            user='postgres',
            password='pass',
            port=str(port),
        )
        try:
            await query_instance(connection, schemas_by_version)
        finally:
            # Always release the server connection, even when the
            # introspection query fails, instead of leaking it.
            await connection.close()

    with open('./edb/pgsql/resolver/sql_introspection.py', 'w') as file:
        print_header(file)
        print_schema(file, "information_schema", schemas_by_version)
        print_schema(file, "pg_catalog", schemas_by_version)

    print('Done.')


async def query_instance(c, schemas_by_version):
    """Introspect one PostgreSQL server and record its catalog columns.

    ``c`` must be an open ``asyncpg.Connection``. The server's major version
    is parsed from the result of ``SELECT version()`` and used as the key
    under which the introspection result is stored in ``schemas_by_version``
    (which is mutated in place). The stored value is the decoded JSON
    produced by the query below: schema name -> table name -> list of
    ``[column_name, type_name]`` pairs.
    """
    import asyncpg

    assert isinstance(c, asyncpg.Connection)

    # Exactly one row with exactly one column is expected.
    version_rows = await c.fetch('SELECT version()')
    [[version_string]] = version_rows
    assert isinstance(version_string, str)

    # Drop the product name, then keep only the leading number of the
    # first whitespace-delimited token (the part before the first dot).
    without_product = version_string.removeprefix('PostgreSQL ')
    release_token = without_product.partition(' ')[0]
    version_major = int(release_token.partition('.')[0])

    print(f'querying PostgreSQL {version_major}')

    # Aggregates information_schema.columns into a single JSON document;
    # the table filter skips information_schema tables starting with "_pg".
    introspection_sql = r'''
    WITH
    columns AS (
        SELECT
            table_schema,
            table_name,
            TO_JSON(ARRAY [
                column_name, COALESCE(domain_name, data_type)
            ]) AS col
        FROM information_schema.columns
        WHERE (
            table_schema = 'information_schema' AND table_name NOT LIKE '\_pg%'
        ) OR (
            table_schema = 'pg_catalog'
        )
        ORDER BY table_schema, table_name, ordinal_position
    ),
    tables AS (
        SELECT table_schema, table_name, JSON_AGG(col) as cols
        FROM columns
        GROUP BY table_schema, table_name
    ),
    schemas AS (
        SELECT table_schema, JSON_OBJECT_AGG(table_name, cols) as tables
        FROM tables
        GROUP BY table_schema
    )
    SELECT JSON_OBJECT_AGG(table_schema, tables)
    FROM schemas
    '''
    result_rows = await c.fetch(introspection_sql)
    [[schemas_json]] = result_rows

    schemas_by_version[version_major] = json.loads(schemas_json)


def print_header(f):
    """Write the fixed preamble of the generated module to file ``f``.

    The preamble consists of the "autogenerated" notice, the module
    docstring, the ``typing`` import and the ``ColumnName``/``ColumnType``
    aliases, followed by one blank line.
    """
    header_template = '''
        # AUTOGENERATED FROM _localdev postgres instance WITH
        #    $ edb gen-sql-introspection

        """Declarations of information schema and pg_catalog"""

        from typing import Tuple, Dict, List

        ColumnName = str
        ColumnType = str | None
        '''
    # The literal opens with a newline purely for source layout; strip it
    # once the common indentation has been removed.
    header = textwrap.dedent(header_template).removeprefix('\n')
    print(header, file=f)


def print_schema(
    f,
    schema_name: str,
    schemas_by_version: Mapping[
        int, Mapping[str, Mapping[str, list[tuple[str, str]]]]
    ],
):
    """
    Writes Python dict source describing the tables of one PostgreSQL schema
    into file `f`.

    `schemas_by_version` maps a PostgreSQL major version to its introspected
    schemas (schema name -> table name -> list of column name/type pairs).
    Tables and columns are taken from the highest version present. Each
    column is emitted as a `(name, type, since)` tuple, where `since` is the
    oldest version reachable by walking back from the newest one without
    hitting a version that lacks the table or the column.

    Types equal to "ARRAY" or starting with "any" are emitted as `None`;
    all other types are emitted as double-quoted strings.
    """
    print(f'Code generation of schema "{schema_name}"')

    newest = max(schemas_by_version)
    # Every version except the newest one, newest first.
    older_versions = sorted(schemas_by_version, reverse=True)[1:]

    latest_tables = schemas_by_version[newest][schema_name]
    final_position = len(latest_tables) - 1

    annotation = ': Dict[str, List[Tuple[ColumnName, ColumnType, int]]]'
    print(f'{schema_name.upper()}{annotation} = {{', file=f)

    for position, (table_name, table_columns) in enumerate(
        latest_tables.items()
    ):
        print(f'    "{table_name}": [', file=f)

        for column_name, column_type in table_columns:
            # Walk back in time; stop at the first version in which the
            # table or the column is missing.
            since = newest
            for version in older_versions:
                older_schemas = schemas_by_version.get(version)
                assert older_schemas
                older_tables = older_schemas.get(schema_name)
                assert older_tables
                older_columns = older_tables.get(table_name)
                if older_columns is None:
                    break
                if not any(
                    name == column_name for name, _ in older_columns
                ):
                    break
                since = version

            if column_type != "ARRAY" and not column_type.startswith("any"):
                escaped = column_type.replace('"', '\\"')
                rendered_type = f'"{escaped}"'
            else:
                rendered_type = "None"
            print(
                f'        ("{column_name}", {rendered_type}, {since}),',
                file=f,
            )

        # No trailing comma after the last table entry.
        separator = '' if position == final_position else ','
        print(f'    ]{separator}', file=f)

    print('}', file=f)


# Command entry point registered under the name "gen-sql-introspection"
# (the generated file header refers to it as `edb gen-sql-introspection`).
# Runs the async `run()` coroutine to completion on a fresh event loop.
# NOTE(review): documented with a comment rather than a docstring, since the
# command decorator may surface docstrings as help text -- confirm.
@edbcommands.command("gen-sql-introspection")
def main():
    asyncio.run(run())
